import warnings
import logging
import os

# 抑制警告
warnings.filterwarnings('ignore')
logging.getLogger('omnigibson').setLevel(logging.ERROR)
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'  # 使用镜像站点

# 如果需要更细致的控制，可以针对特定警告
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', message='.*particleRemover.*')
warnings.filterwarnings('ignore', message='.*ClothStateMixin.*')

ALGO_NAME = 'diffusion_with_rl_top_img_single_arm'

import os
import argparse
import random
from distutils.util import strtobool

os.environ["OMP_NUM_THREADS"] = "1"

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter

import datetime
from datetime import datetime
from collections import defaultdict
import sys


from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from nets.diffusion_policy.conditional_unet1d_colab import ConditionalUnet1D
import os
import click
import yaml
import torch
import numpy as np
# from omnigibson.envs import DataPlaybackWrapper
from diffusion_policy.policy.diffusion_transformer_hybrid_image_policy import DiffusionTransformerHybridImagePolicy

import os
import torch
import numpy as np
import omnigibson as og
# from diffusion_policy.policy.diffusion_transformer_hybrid_image_policy import DiffusionTransformerHybridImagePolicy
# from omnigibson.envs import DataPlaybackWrapper
import yaml
import time
import os
import json
import math
import numpy as np
import torch as th
import yaml
from scipy.spatial.transform import Rotation as R
from multiprocessing import shared_memory, Queue, Event
import faulthandler
import omnigibson as og
import omnigibson.lazy as lazy
from omnigibson.macros import gm
from omnigibson.utils.ui_utils import KeyboardEventHandler
# from omnigibson.envs import DataPlaybackWrapper
import torch.nn.functional as F
import sys
import os
import yaml
import time
from multiprocessing.managers import SharedMemoryManager
import click
import cv2
import numpy as np
import torch
import dill
import hydra
from omegaconf import OmegaConf
import sys
sys.path.append("/home/admin01/project/") 
from diffusion_policy.workspace.base_workspace import BaseWorkspace
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
import h5py
import sys
sys.path.append("/home/admin01/project/DexDiffusionPolicy")
import yaml  # 用于读取配置文件
import dill  # 用于加载模型
import hydra  # 用于加载diffusion policy
from diffusion_policy.workspace.base_workspace import BaseWorkspace
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
# Add the sibling "Benchmark" directory (resolved relative to this file) to sys.path so that Benchmark_results can be imported below
current_project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))  # policy_decorator dir
project_dir = os.path.dirname(current_project_dir)  # /home/admin01/project
benchmark_path = os.path.join(project_dir, "Benchmark")
sys.path.append(benchmark_path)
sys.path.append("/home/admin01/project/policy_decorator")
from policy_decorator.utils.profiling import NonOverlappingTimeProfiler

# Now we can import from Benchmark_results
from Benchmark_results.evaluate.env_aug import EnvAugmentor

# Default configuration file path (used as the default for --config and read by get_domain_info)
RL_TRAIN_CONFIG_PATH = "/home/admin01/project/path/to/policy_config.yaml"

def get_domain_info(args):
    """Collect the configuration files that describe the training domain.

    Four YAML files are parsed and returned together in one dictionary:

    * ``env_augmentor``   - the file at ``args.paths.env_augmentor_cfg``
    * ``obj_config``      - the file named by the ``obj_config_path`` entry of
      the augmentor configuration
    * ``env_config``      - the file at ``args.paths.env_cfg``
    * ``rl_train_config`` - the file at the module constant
      ``RL_TRAIN_CONFIG_PATH``

    Args:
        args: Namespace exposing ``paths.env_augmentor_cfg`` and
            ``paths.env_cfg``.

    Returns:
        dict: The four parsed configurations under the keys listed above.
    """
    def _read_yaml(path):
        # NOTE(review): FullLoader can construct arbitrary Python objects;
        # only point this at trusted configuration files.
        with open(path, "r") as f:
            return yaml.load(f, Loader=yaml.FullLoader)

    env_aug_config = _read_yaml(args.paths.env_augmentor_cfg)
    domain_info = {'env_augmentor': env_aug_config}

    # The object configuration path is itself stored in the augmentor config.
    domain_info['obj_config'] = _read_yaml(env_aug_config['obj_config_path'])
    domain_info['env_config'] = _read_yaml(args.paths.env_cfg)

    # The RL training parameters always come from the module-level default path.
    domain_info['rl_train_config'] = _read_yaml(RL_TRAIN_CONFIG_PATH)

    return domain_info

def setup_logging_dir(args):
    """Create a timestamp-named results directory and return its path.

    The directory is ``<args.paths.results_base>/<YYYYmmdd_HHMMSS>`` using the
    current local time; it is created if it does not exist yet.

    Args:
        args: Namespace exposing ``paths.results_base``.

    Returns:
        str: Path of the (now existing) run directory.
    """
    stamp = f"{datetime.now():%Y%m%d_%H%M%S}"
    target = os.path.join(args.paths.results_base, stamp)
    os.makedirs(target, exist_ok=True)
    return target

def setup_env_augmentor(args):
    """Build an ``EnvAugmentor`` from the augmentor YAML configuration.

    Args:
        args: Namespace exposing ``paths.env_augmentor_cfg``, the path of a
            YAML file that must contain every key listed in ``required_keys``
            below.

    Returns:
        EnvAugmentor: Instance constructed with the configured values, each
        passed as the keyword argument of the same name.
    """
    required_keys = (
        "texture_dir",
        "default_texture_path",
        "obj_config_path",
        "position_offset",
        "default_pos",
        "light_color",
        "light_intensity",
        "default_obj_config_path",
    )
    with open(args.paths.env_augmentor_cfg, "r") as cfg_file:
        cfg = yaml.load(cfg_file, Loader=yaml.FullLoader)

    # Each configuration entry maps one-to-one onto a constructor keyword.
    kwargs = {key: cfg[key] for key in required_keys}
    return EnvAugmentor(**kwargs)
def convert_to_json_serializable(obj):
    """Recursively convert ``obj`` into something ``json.dump`` accepts.

    Conversion rules, checked in this order:

    1. ``dict``            -> dict with converted values (keys are kept as-is)
    2. has ``__dict__``    -> dict of its public attributes (names that do not
       start with ``_``), converted recursively
    3. ``list`` / ``tuple`` -> list of converted items
    4. ``int``, ``float``, ``str``, ``bool``, ``None`` -> returned unchanged
    5. anything else       -> ``str(obj)``
    """
    if isinstance(obj, dict):
        return {key: convert_to_json_serializable(value)
                for key, value in obj.items()}

    if hasattr(obj, '__dict__'):
        public_attrs = {}
        for attr_name, attr_value in vars(obj).items():
            if attr_name.startswith('_'):
                continue
            public_attrs[attr_name] = convert_to_json_serializable(attr_value)
        return public_attrs

    if isinstance(obj, (list, tuple)):
        return [convert_to_json_serializable(item) for item in obj]

    if isinstance(obj, (int, float, str, bool, type(None))):
        return obj

    # Last resort: fall back to the object's string representation.
    return str(obj)
def parse_args():
    """Build the run configuration from a YAML file plus command-line overrides.

    The YAML file (``--config``, defaulting to ``RL_TRAIN_CONFIG_PATH``) supplies
    every hyper-parameter. Its ``paths`` mapping is exposed as a
    ``SimpleNamespace`` at ``args.paths``; every other top-level key becomes an
    attribute of the returned namespace. ``--exp-name``, ``--seed`` and
    ``--resume-path`` override the corresponding YAML values when given.

    Post-processing applied to the loaded values:

    * ``buffer_size`` defaults to ``total_timesteps`` and is capped by it.
    * ``num_eval_envs`` is capped by ``num_eval_episodes``.
    * ``resume_path`` is set to ``None`` when neither the YAML file nor the
      command line provides it, so callers can test it unconditionally.

    Returns:
        argparse.Namespace: The merged configuration.

    Raises:
        AssertionError: If ``num_eval_episodes`` is not a multiple of
            ``num_eval_envs``, ``training_freq`` is not a multiple of
            ``num_envs``, or ``training_freq * utd`` is not a whole number.
    """
    import argparse
    import yaml
    from types import SimpleNamespace

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default=RL_TRAIN_CONFIG_PATH,
                       help="path to config file")
    parser.add_argument("--exp-name", type=str, help="override experiment name")
    parser.add_argument("--seed", type=int, help="override random seed")
    parser.add_argument("--resume-path", type=str, help="path to checkpoint to resume training from")
    args_cmd = parser.parse_args()

    # Load config file
    with open(args_cmd.config, 'r') as f:
        args_yaml = yaml.safe_load(f)

    # Expose the ``paths`` mapping as attributes (args.paths.<name>); a missing
    # or empty ``paths:`` entry yields an empty namespace.
    paths_dict = args_yaml.pop('paths', {}) or {}
    args_yaml['paths'] = SimpleNamespace(**paths_dict)

    # Convert yaml to argparse namespace
    args = argparse.Namespace(**args_yaml)

    # The training script reads args.resume_path unconditionally, so make sure
    # the attribute exists even when the YAML file does not define it.
    if not hasattr(args, 'resume_path'):
        args.resume_path = None

    # Override with command line arguments if provided
    if args_cmd.exp_name is not None:
        args.exp_name = args_cmd.exp_name
    if args_cmd.seed is not None:
        args.seed = args_cmd.seed
    if args_cmd.resume_path is not None:
        args.resume_path = args_cmd.resume_path

    # Post-processing
    if args.buffer_size is None:
        args.buffer_size = args.total_timesteps
    args.buffer_size = min(args.total_timesteps, args.buffer_size)
    args.num_eval_envs = min(args.num_eval_envs, args.num_eval_episodes)

    # Validations
    assert args.num_eval_episodes % args.num_eval_envs == 0, \
        "num_eval_episodes must be a multiple of num_eval_envs"
    assert args.training_freq % args.num_envs == 0, \
        "training_freq must be a multiple of num_envs"
    # float(...) keeps this working when both values are ints: int.is_integer()
    # only exists on Python 3.12+.
    assert float(args.training_freq * args.utd).is_integer(), \
        "training_freq * utd must be a whole number"

    return args

###############ours####################
def load_diffusion_policy_model(checkpoint_path):
    """Load a diffusion-policy workspace checkpoint and return its policy.

    The checkpoint payload is expected to hold a ``cfg`` entry whose
    ``_target_`` names the workspace class. The workspace is instantiated,
    restored from the payload, and its ``model`` attribute is returned in eval
    mode on the CUDA device with ``num_inference_steps = 16`` and
    ``n_action_steps = 8``.

    Side effects: HTTPS certificate verification is disabled process-wide and
    ``HF_HUB_DISABLE_SSL_VERIFICATION`` is set in the environment.

    Args:
        checkpoint_path: Path of the checkpoint file (unpickled with ``dill``).

    Returns:
        The restored policy (``workspace.model``).

    Raises:
        ValueError: If ``cfg.name`` does not contain ``'diffusion'``; the
            original code failed with an ``UnboundLocalError`` in that case.
    """
    # SECURITY NOTE(review): the next lines turn off TLS certificate
    # verification for the whole process. Kept because the original relies on
    # it, but this should be removed once certificates are set up properly.
    import ssl
    ssl._create_default_https_context = ssl._create_unverified_context

    import requests
    from requests.packages.urllib3.exceptions import InsecureRequestWarning
    requests.packages.urllib3.disable_warnings(InsecureRequestWarning)

    # Tell huggingface_hub to skip SSL verification as well.
    import os
    os.environ['HF_HUB_DISABLE_SSL_VERIFICATION'] = '1'

    # SECURITY NOTE(review): unpickling with dill executes arbitrary code from
    # the file; only load checkpoints from trusted sources.
    # The context manager makes sure the file handle is closed after loading.
    with open(checkpoint_path, 'rb') as ckpt_file:
        payload = torch.load(ckpt_file, pickle_module=dill)
    cfg = payload['cfg']

    cls = hydra.utils.get_class(cfg._target_)
    workspace: BaseWorkspace = cls(cfg)
    workspace.load_payload(payload, exclude_keys=None, include_keys=None)

    if 'diffusion' not in cfg.name:
        raise ValueError(
            f"Checkpoint {checkpoint_path!r} is not a diffusion policy "
            f"(cfg.name={cfg.name!r})"
        )

    # NOTE: the plain model is used here, not workspace.ema_model.
    policy: BaseImagePolicy = workspace.model

    device = torch.device('cuda')
    policy.eval().to(device)

    # set inference params
    policy.num_inference_steps = 16  # DDIM inference iterations
    policy.n_action_steps = 8
    return policy

def load_environment(args):
    """Create the OmniGibson environment and prepare it for the first episode.

    Reads the environment YAML at ``args.paths.env_cfg``, builds the
    environment from its ``sim`` section, drives the robot to its initial pose
    and initialises the task's reward bookkeeping with the current position of
    the object registered as ``ball0``.

    Args:
        args: Namespace exposing ``paths.env_cfg``.

    Returns:
        The constructed environment.
    """
    with open(args.paths.env_cfg, "r") as cfg_file:
        env_cfg = yaml.load(cfg_file, Loader=yaml.FullLoader)

    env = og.Environment(configs=env_cfg["sim"])
    move_to_initial_position(env)

    # The task needs the object's start position to compute rewards.
    start_pos = env.scene.object_registry("name", "ball0").get_position()
    env._task.reset_for_new_task(start_pos)
    env._task.calculate_init_reward(env, start_pos)
    return env

def env_reset(env, env_augmentor):
    """Reset the environment and re-randomise it for a new episode.

    Steps, in order: reset the environment, restore the augmentor's default
    state, move the robot to its initial pose, apply a random augmentation,
    then re-initialise the task's reward bookkeeping with the current position
    of the object registered as ``ball0``.

    Args:
        env: The environment to reset.
        env_augmentor: Augmentor providing ``restore2defaultstate`` and
            ``apply_random_aug``.
    """
    env.reset()
    env_augmentor.restore2defaultstate(env)
    move_to_initial_position(env)
    env_augmentor.apply_random_aug(env)

    # The task needs the (possibly augmented) start position to compute rewards.
    start_pos = env.scene.object_registry("name", "ball0").get_position()
    env._task.reset_for_new_task(start_pos)
    env._task.calculate_init_reward(env, start_pos)




def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place and return it.

    The weight is filled with an orthogonal matrix scaled by ``std`` (used as
    the gain) and every bias entry is set to ``bias_const``.

    Args:
        layer: Module with ``weight`` and ``bias`` tensors (e.g. ``nn.Linear``).
        std: Gain passed to the orthogonal initialiser.
        bias_const: Constant written to the bias.

    Returns:
        The same ``layer`` object, for chaining.
    """
    nn.init.orthogonal_(layer.weight, gain=std)
    nn.init.constant_(layer.bias, val=bias_const)
    return layer

# ALGO LOGIC: initialize agent here:
class SoftQNetwork(nn.Module):
    """State-action value network: an MLP over the concatenated (obs, action).

    Three hidden layers of width 256 with ReLU activations, followed by a
    single-output linear layer initialised through ``layer_init`` with
    ``std=0.01``.
    """

    def __init__(self, obs_dim, action_dim):
        """
        Args:
            obs_dim: Width of the observation vector.
            action_dim: Width of the action vector.
        """
        super().__init__()
        hidden = 256
        # Layers are created in the same order (and end up at the same
        # nn.Sequential indices) so initialisation and state_dict keys match.
        layers = [nn.Linear(obs_dim + action_dim, hidden), nn.ReLU()]
        for _ in range(2):
            layers.extend([nn.Linear(hidden, hidden), nn.ReLU()])
        layers.append(layer_init(nn.Linear(hidden, 1), std=0.01))
        self.net = nn.Sequential(*layers)

    def forward(self, x, a):
        """Return the Q-value for observations ``x`` and actions ``a``.

        Both inputs are concatenated along dimension 1 before entering the MLP.
        """
        obs_and_action = torch.cat((x, a), dim=1)
        return self.net(obs_and_action)

# Bounds for the policy's log standard deviation: Actor.forward rescales its
# tanh-squashed output into the interval [LOG_STD_MIN, LOG_STD_MAX].
LOG_STD_MIN = -20
LOG_STD_MAX = 2
def get_action_space_boundaries():
    """Return the affine map from a tanh-squashed action to the real action range.

    The 18-dimensional action is laid out as:

    * indices 0-2:   end-effector position component, bounds [-0.1, 0.1]
      (marked "need to check" by the original author)
    * indices 3-5:   end-effector orientation component, bounds [-1.5, 1.5]
      (marked "need to check" by the original author)
    * indices 6-9:   thumb joints joint11..joint14
    * indices 10-17: joint21, joint22, joint31, joint32, joint41, joint42,
      joint51, joint52 (two joints for each of the four remaining fingers)

    The finger limits are the values the original comments attribute to the
    URDF; every finger lower bound is 0.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(action_scale, action_bias)`` as
        float32 tensors of shape (18,), where ``scale = (high - low) / 2`` and
        ``bias = (high + low) / 2``. Both are also printed.
    """
    # Upper joint limits in radians as given by the source comments (URDF).
    thumb_upper = (1.3089969, 0.6806784, 0.8028515, 0.5934119)  # joint11..14
    first_joint_upper = 1.5358897   # joint21 / 31 / 41 / 51
    second_joint_upper = 1.9896753  # joint22 / 32 / 42 / 52

    # One (low, high) pair per action dimension, in action order.
    bounds = (
        [(-0.1, 0.1)] * 3
        + [(-1.5, 1.5)] * 3
        + [(0.0, upper) for upper in thumb_upper]
        + [(0.0, first_joint_upper), (0.0, second_joint_upper)] * 4
    )
    table = np.asarray(bounds, dtype=np.float64)

    low_tensor = torch.tensor(table[:, 0], dtype=torch.float32)
    high_tensor = torch.tensor(table[:, 1], dtype=torch.float32)

    action_scale = (high_tensor - low_tensor) / 2.0
    action_bias = (high_tensor + low_tensor) / 2.0
    print("action_scale:",action_scale)
    print("action_bias:",action_bias)
    return action_scale, action_bias
class Actor(nn.Module):
    """Tanh-squashed Gaussian policy that outputs the residual action.

    A three-layer MLP backbone (width 256, ReLU) feeds two linear heads that
    produce the mean and the log standard deviation of an 18-dimensional
    Gaussian. Samples are squashed with ``tanh`` and, when action scaling is
    enabled, mapped to the ranges returned by ``get_action_space_boundaries``.
    """

    def __init__(self, env, args):
        """
        Args:
            env: Unused; kept for interface compatibility.
            args: Configuration namespace. ``args.actor_input`` selects the
                input: ``'obs'`` feeds only the 72-dim state, anything else
                feeds the state concatenated with an 18-dim action (90 dims).
                ``args.use_action_scaling`` (default ``True``) toggles the
                mapping of squashed actions to the real action range.
        """
        super().__init__()
        # joint_qpos + joint_qvel + gripper_qpos + gripper_qvel + eef_pos + eef_quat + obj_pos
        proprio_dim = (19 + 19 + 12 + 12 + 3 + 4 + 3)
        obs_dim = proprio_dim
        action_dim = 18
        # Stored so the error fallback in get_action can build correctly shaped
        # outputs (plain attribute, not part of the state_dict).
        self.action_dim = action_dim
        input_dim = obs_dim if args.actor_input == 'obs' else obs_dim + action_dim  # 72 or 90
        self.backbone = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
        )
        # Output heads for the Gaussian parameters.
        self.fc_mean = layer_init(nn.Linear(256, action_dim), std=0.01)
        self.fc_logstd = layer_init(nn.Linear(256, action_dim), std=0.01)

        # Whether squashed actions are mapped to the real action range.
        self.use_action_scaling = getattr(args, 'use_action_scaling', True)

        # Buffers (not parameters) so they follow the module across devices.
        if self.use_action_scaling:
            action_scale, action_bias = get_action_space_boundaries()
            self.register_buffer("action_scale", action_scale)
            self.register_buffer("action_bias", action_bias)

    def forward(self, x):
        """Return ``(mean, log_std)`` of the pre-squash Gaussian for input ``x``.

        ``log_std`` is passed through ``tanh`` and rescaled into
        ``[LOG_STD_MIN, LOG_STD_MAX]``.
        """
        x = self.backbone(x)
        mean = self.fc_mean(x)
        log_std = self.fc_logstd(x)
        log_std = torch.tanh(log_std)
        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1)  # From SpinUp / Denis Yarats

        return mean, log_std

    def get_action(self, x, deterministic=False):
        """Sample an action and its log-probability.

        NaN values found in the input or in any intermediate tensor are
        reported with a printed warning and replaced by a finite value.

        Args:
            x: Batched actor input of shape (batch, input_dim).
            deterministic: When true the Gaussian mean is used in place of a
                reparameterised sample.

        Returns:
            tuple: ``(action, log_prob, mean_action)`` with shapes
            (batch, 18), (batch, 1) and (batch, 18). If any exception occurs
            the error is printed and all-zero tensors of those shapes are
            returned as a best-effort fallback.
        """
        try:
            # Guard against NaNs in the input.
            if torch.isnan(x).any():
                print("警告: 输入包含NaN值")
                x = torch.nan_to_num(x, nan=0.0)

            mean, log_std = self(x)

            # Guard against NaNs in the distribution parameters.
            if torch.isnan(mean).any():
                print("警告: 均值包含NaN值")
                mean = torch.nan_to_num(mean, nan=0.0)

            if torch.isnan(log_std).any():
                print("警告: 对数标准差包含NaN值")
                log_std = torch.nan_to_num(log_std, nan=-10.0)

            std = log_std.exp()

            if torch.isnan(std).any():
                print("警告: 标准差包含NaN值")
                std = torch.nan_to_num(std, nan=1.0)

            normal = torch.distributions.Normal(mean, std)

            if deterministic:
                x_t = mean
            else:
                # Reparameterisation trick keeps the sample differentiable.
                x_t = normal.rsample()

            if torch.isnan(x_t).any():
                print("警告: 采样值包含NaN")
                x_t = torch.nan_to_num(x_t, nan=0.0)

            y_t = torch.tanh(x_t)

            # Change-of-variables correction for the tanh squash (and the
            # affine scaling when enabled); 1e-6 keeps the log argument positive.
            if self.use_action_scaling:
                action = y_t * self.action_scale + self.action_bias
                log_prob = normal.log_prob(x_t)
                denom = self.action_scale * (1 - y_t.pow(2) + 1e-6)
                log_prob -= torch.log(denom)
            else:
                action = y_t
                log_prob = normal.log_prob(x_t)
                log_prob -= torch.log((1 - y_t.pow(2)) + 1e-6)

            # Sum the per-dimension log-probabilities.
            log_prob = log_prob.sum(1, keepdim=True)

            if torch.isnan(log_prob).any():
                print("警告: log_prob包含NaN值")
                log_prob = torch.nan_to_num(log_prob, nan=0.0)

            # Squashed (and optionally scaled) mean, i.e. the deterministic action.
            mean_tanh = torch.tanh(mean)
            if self.use_action_scaling:
                mean_action = mean_tanh * self.action_scale + self.action_bias
            else:
                mean_action = mean_tanh

            return action, log_prob, mean_action

        except Exception as e:
            print(f"get_action方法出错: {e}")
            # Best-effort fallback: zero tensors with the *action* width. The
            # width comes from the stored attribute or, failing that, from the
            # mean head -- never from the input width, which differs from it.
            batch_size = x.shape[0]
            action_dim = getattr(self, 'action_dim', self.fc_mean.out_features)
            device = x.device
            return (
                torch.zeros((batch_size, action_dim), device=device),
                torch.zeros((batch_size, 1), device=device),
                torch.zeros((batch_size, action_dim), device=device)
            )

    def get_eval_action(self, x):
        """Return the deterministic evaluation action for input ``x``.

        This is ``tanh`` of the mean head, mapped to the real action range when
        action scaling is enabled.
        """
        x = self.backbone(x)
        mean = self.fc_mean(x)
        mean = torch.tanh(mean)

        if self.use_action_scaling:
            action = mean * self.action_scale + self.action_bias
        else:
            action = mean

        return action

    def to(self, *args, **kwargs):
        """Forward to ``nn.Module.to`` with all of its supported arguments.

        The previous override accepted only a single positional ``device``;
        that call form still works unchanged.
        """
        return super().to(*args, **kwargs)


def collect_episode_info(infos, result=None, global_step=None):
    """Accumulate statistics of the episodes that finished in this step.

    For every finished environment (flagged in ``infos["_final_info"]``) the
    episode return, length and success flag are printed and appended to
    ``result['return']``, ``result['len']`` and ``result['success']``.

    Args:
        infos: Vectorised-environment info dict. Nothing is collected unless
            it contains ``"final_info"``.
        result: Optional mapping of lists to extend; a fresh
            ``defaultdict(list)`` is created when omitted.
        global_step: Step counter shown in the log line. When omitted, the
            module-level ``global_step`` is used if it exists (the script sets
            it in its main section); otherwise ``None`` is printed. The
            original code raised ``NameError`` in that last case.

    Returns:
        The ``result`` mapping.
    """
    if result is None:
        result = defaultdict(list)
    if "final_info" not in infos:  # infos is a dict
        return result

    if global_step is None:
        # Backward compatible: fall back to the script-level counter.
        global_step = globals().get("global_step")

    # Not all envs are done at the same time.
    finished = np.where(infos["_final_info"])[0]
    for i in finished:
        info = infos["final_info"][i]  # info is also a dict
        ep = info['episode']
        print(f"global_step={global_step}, ep_return={ep['r'][0]:.2f}, ep_len={ep['l'][0]}, success={info['success']}")
        result['return'].append(ep['r'][0])
        result['len'].append(ep["l"][0])
        result['success'].append(info['success'])
    return result

def is_ms1_env(env_id):
    """Return True when ``env_id`` contains one of the three listed task names."""
    ms1_markers = ('OpenCabinet', 'MoveBucket', 'PushChair')
    return any(marker in env_id for marker in ms1_markers)
def process_image(img):
    """Convert an HxWxC image into a normalised 3x224x224 float tensor.

    Steps: keep the first three channels, convert to float, move channels
    first, scale to [0, 1] and resize bilinearly to 224x224.

    Scaling rule: ``uint8`` input is always divided by 255. For any other
    dtype the original heuristic is kept -- the image is divided by 255 unless
    its maximum is already <= 1.0. (Previously the heuristic was also applied
    to ``uint8`` input, so a very dark frame whose pixels were all <= 1 was
    mistaken for an already-normalised image.)

    Args:
        img: ``torch.Tensor`` or NumPy array of shape (H, W, C) with C >= 3.

    Returns:
        torch.Tensor: Float tensor of shape (3, 224, 224).
    """
    # Keep only the RGB channels; NumPy input is wrapped as a tensor.
    if isinstance(img, torch.Tensor):
        rgb = img[..., :3]
    else:
        rgb = torch.from_numpy(img[..., :3])

    # Remember the source dtype before converting: uint8 always means 0..255.
    is_uint8 = rgb.dtype == torch.uint8

    # [H, W, C] -> [C, H, W]
    rgb = rgb.float().permute(2, 0, 1)

    if is_uint8:
        needs_scaling = True
    else:
        is_normalized = rgb.max().item() <= 1.0
        needs_scaling = not is_normalized
    if needs_scaling:
        rgb = rgb / 255.0

    # Resize to 224x224 (a batch dimension is needed by interpolate).
    resized = F.interpolate(
        rgb.unsqueeze(0),
        size=(224, 224),
        mode='bilinear',
        align_corners=False
    ).squeeze(0)

    return resized
def preprocess_observation(obs, model):
    """Turn a raw environment observation into the base policy's input dict.

    The result holds the processed ``external_sensor3`` RGB image plus six
    slices of the robot's proprioception vector. Every value gets two leading
    singleton dimensions (stacked as a one-element sequence, then batched).

    Args:
        obs: Observation dict, or a sequence whose first element is that dict.
        model: Unused; kept for interface compatibility.

    Returns:
        dict: Maps each observation key to a tensor with two leading
        dimensions of size 1.
    """
    raw = obs if isinstance(obs, dict) else obs[0]
    proprio = raw['panda']['proprio']

    # (key, start, stop) layout of the proprioception vector.
    proprio_layout = (
        ('joint_qpos', 0, 19),
        ('joint_qvel', 19, 38),
        ('gripper_0_qpos', 38, 50),
        ('gripper_0_qvel', 50, 62),
        ('eef_0_pos', 62, 65),
        ('eef_0_quat', 65, 69),
    )

    current_obs = {
        'external_sensor3_img': process_image(raw['external']['external_sensor3']['rgb']),
    }
    for key, start, stop in proprio_layout:
        current_obs[key] = proprio[start:stop]

    # Stack as a length-1 sequence (current observation only), then add batch.
    return {
        key: torch.stack([value], dim=0).unsqueeze(0)
        for key, value in current_obs.items()
    }


def extract_proprio(env,obs):
    """Build the state vector: robot proprioception followed by object position.

    Args:
        env: Environment whose scene registry holds an object named ``ball0``.
        obs: Observation dict, or a sequence whose first element is that dict.

    Returns:
        np.ndarray: Concatenation of the first 69 proprioception entries
        (joint_qpos, joint_qvel, gripper_0_qpos, gripper_0_qvel, eef_0_pos,
        eef_0_quat) and the position returned by ``ball0.get_position()``.
    """
    raw = obs if isinstance(obs, dict) else obs[0]
    proprio_full = raw['panda']['proprio']
    obj_pos = env.scene.object_registry("name", "ball0").get_position()

    # (start, stop) of each proprioception segment, in output order.
    segments = (
        (0, 19),    # joint_qpos
        (19, 38),   # joint_qvel
        (38, 50),   # gripper_0_qpos
        (50, 62),   # gripper_0_qvel
        (62, 65),   # eef_0_pos
        (65, 69),   # eef_0_quat
    )
    pieces = [proprio_full[start:stop] for start, stop in segments]
    pieces.append(obj_pos)
    return np.concatenate(pieces)

def load_checkpoint(path, res_actor, qf1, qf2, qf1_target, qf2_target, log_sac_alpha, 
                   actor_optimizer, q_optimizer, a_optimizer=None, policy_lr=1.0e-4, q_lr=1.0e-4,
                   load_global_step=True):
    """Restore networks and optimizers from a checkpoint, resetting learning rates.

    Entries are restored only when present in the checkpoint, except
    ``'res_actor'`` which is required. After an optimizer state is restored its
    learning rate is overwritten: ``policy_lr * 0.5`` for the actor,
    ``q_lr * 0.5`` for the Q-networks and ``q_lr * 0.1`` for the entropy
    coefficient.

    If the checkpoint has no ``'q1_target'`` / ``'q2_target'`` entry, the
    target network is synchronised from the corresponding online network. (The
    original indexed ``checkpoint['q1']`` / ``['q2']`` directly here and raised
    ``KeyError`` when those were absent too.)

    Args:
        path: Checkpoint file path.
        res_actor: Residual policy network.
        qf1, qf2: Online Q-networks.
        qf1_target, qf2_target: Target Q-networks.
        log_sac_alpha: Log entropy-coefficient tensor, or ``None`` to skip it.
        actor_optimizer: Optimizer of the policy.
        q_optimizer: Optimizer of the Q-networks.
        a_optimizer: Optimizer of the entropy coefficient, or ``None``.
        policy_lr: Base learning rate of the policy.
        q_lr: Base learning rate of the Q-networks.
        load_global_step: Whether to return the checkpoint's step counter.

    Returns:
        tuple[bool, int]: ``(True, step)`` on success, where ``step`` is the
        stored ``global_step`` (0 if absent or if ``load_global_step`` is
        false); ``(False, 0)`` when ``path`` does not exist.
    """
    if not os.path.exists(path):
        print(f"检查点不存在: {path}")
        return False, 0

    print(f"正在加载检查点: {path}")
    # Map tensors onto the actor's device so a checkpoint written on a GPU can
    # also be loaded on a CPU-only host.
    first_param = next(iter(res_actor.parameters()), None)
    map_location = first_param.device if first_param is not None else None
    checkpoint = torch.load(path, map_location=map_location)

    # Policy weights (required).
    res_actor.load_state_dict(checkpoint['res_actor'])

    # Online Q-networks.
    if 'q1' in checkpoint:
        qf1.load_state_dict(checkpoint['q1'])
    if 'q2' in checkpoint:
        qf2.load_state_dict(checkpoint['q2'])

    # Target Q-networks: use the stored copy, else mirror the online network
    # (which already holds the checkpoint's 'q1'/'q2' weights if present).
    if 'q1_target' in checkpoint:
        qf1_target.load_state_dict(checkpoint['q1_target'])
    else:
        qf1_target.load_state_dict(qf1.state_dict())

    if 'q2_target' in checkpoint:
        qf2_target.load_state_dict(checkpoint['q2_target'])
    else:
        qf2_target.load_state_dict(qf2.state_dict())

    # Entropy coefficient: keep the same tensor object (its optimizer holds a
    # reference to it) and make sure the data lives on its current device.
    if 'log_sac_alpha' in checkpoint and log_sac_alpha is not None:
        log_sac_alpha.data = checkpoint['log_sac_alpha'].detach().to(log_sac_alpha.device)

    # Optimizer states, each followed by its learning-rate reset.
    if 'actor_optimizer' in checkpoint:
        actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
        for param_group in actor_optimizer.param_groups:
            param_group['lr'] = policy_lr * 0.5  # half of the policy learning rate
            print(f"policy_lr: {param_group['lr']}")

    if 'q_optimizer' in checkpoint:
        q_optimizer.load_state_dict(checkpoint['q_optimizer'])
        for param_group in q_optimizer.param_groups:
            param_group['lr'] = q_lr * 0.5  # half of the Q learning rate
            print(f"q_lr: {param_group['lr']}")

    if a_optimizer is not None and 'a_optimizer' in checkpoint:
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        for param_group in a_optimizer.param_groups:
            param_group['lr'] = q_lr * 0.1  # a tenth of the Q learning rate

    checkpoint_global_step = checkpoint.get('global_step', 0)

    if load_global_step:
        global_step = checkpoint_global_step
        print(f"加载成功! 当前步数: {global_step}")
        return True, global_step

    print(f"加载成功! 原始步数: {checkpoint_global_step}，重置为0")
    return True, 0


def generate_action(robot, arm_targets):
    """Assemble a full robot action that drives the arm and hand to their targets.

    Each controller of the robot contributes one slice of the action:

    * ``arm_0``: position delta (target minus current end-effector position)
      followed by the target orientation taken from ``arm_targets["arm_0"]``.
    * ``gripper_0``: the value of ``arm_targets["gripper_0"]`` unchanged.
    * any other controller: its no-op action.

    Args:
        robot: The robot object to control.
        arm_targets: Dict keyed by controller name. ``"arm_0"`` maps to a
            ``(target_pos, target_orn_axisangle)`` pair and ``"gripper_0"`` to
            the gripper command.

    Returns:
        th.Tensor: Action vector of length ``robot.action_dim``.
    """
    action = th.zeros(robot.action_dim)
    for name, controller in robot._controllers.items():
        if name == "gripper_0":
            chunk = arm_targets[name]
        elif name == "arm_0":
            goal_pos, goal_orn_axisangle = arm_targets[name]
            # The end-effector is looked up by the arm's suffix ("0").
            eef_pos = robot.get_eef_position(name.replace("arm_", ""))
            chunk = th.cat((goal_pos - eef_pos, goal_orn_axisangle))
        else:
            chunk = controller.compute_no_op_action(robot.get_control_dict())

        action[robot.controller_action_idx[name]] = chunk
    return action


def move_to_initial_position(env):
    """Drive the first robot towards its fixed initial arm and hand pose.

    The same target action is regenerated and applied for exactly 60
    environment steps; there is no early exit when the pose is reached.

    Args:
        env: Environment whose ``robots[0]`` is stepped.
    """
    # Fixed targets keyed by controller name, as expected by generate_action.
    targets = {
        "arm_0": (
            torch.tensor([0.05, 0.15 ,1.2000]),       # end-effector position
            torch.tensor([0.0000, 0.0000, -1.5708]),  # end-effector orientation
        ),
        "gripper_0": torch.tensor([8.0684e-01, 3.4694e-18, 0.0000e+00, 2.5917e-38,
                                   1.7000e+00, 2.7020e-01, 1.7000e+00, 2.8908e-01,
                                   1.7000e+00, 3.7085e-01, 1.7000e+00, 3.8673e-01]),
    }

    # Fixed budget of 60 steps.
    for _ in range(60):
        step_action = generate_action(env.robots[0], targets)
        _obs, _, _, _, _ = env.step(step_action)



def convert_to_serializable(obj):
    """Recursively replace tensors inside ``obj`` with plain Python values.

    ``torch.Tensor`` values become what ``tolist()`` returns; dicts, lists and
    tuples are rebuilt with converted contents (tuples stay tuples); any other
    object is returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        return obj.tolist()
    if isinstance(obj, dict):
        return {key: convert_to_serializable(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert_to_serializable(entry) for entry in obj]
    if isinstance(obj, tuple):
        return tuple(convert_to_serializable(entry) for entry in obj)
    return obj


if __name__ == "__main__":
    args = parse_args()
    
    # Set up logging path
    now = datetime.now().strftime("%y%m%d-%H%M%S")
    tag = f"{now}_{args.seed}"
    if args.exp_name:
        tag += f"_{args.exp_name}"
    log_name = os.path.join(args.env_id, args.algo_name, tag)
    log_path = os.path.join(args.output_dir, log_name)

    # 初始化wandb记录
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=log_name.replace(os.path.sep, "__"),
            monitor_gym=True,
            save_code=True,
        )
    
    # 设置tensorboard记录器
    writer = SummaryWriter(log_path)
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )
    
    # 保存参数到json文件
    import json
    with open(f'{log_path}/args.json', 'w') as f:
        # json.dump(vars(args), f, indent=4)
        json.dump(convert_to_json_serializable(args), f, indent=4)

    # TRY NOT TO MODIFY: seeding
    # 设置随机种子
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    # 设置运行设备(GPU/CPU)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    from os.path import dirname as up
    exp_dir = up(up(args.base_policy_ckpt))
    
    envs=load_environment(args)
    env_augmentor = setup_env_augmentor(args)
    # Load base policy
    # 加载基准策略
    base_policy = load_diffusion_policy_model(args.base_policy_ckpt)
    base_policy.eval()
    base_policy.requires_grad_(False)
    # 创建残差actor和critic网络
    res_actor = Actor(envs, args).to(device)
    # 定义输入维度
    proprio_dim = (19 + 19 + 12 + 12 + 3 + 4+3)  # joint_qpos + joint_qvel + gripper_qpos + gripper_qvel + eef_pos + eef_quat
    action_dim = 18  # 机器人的动作维度
    # 根据critic_input模式确定critic的动作输入维度
    if args.critic_input == 'concat':
        action_dim_for_critic = action_dim * 2  # 拼接基础动作和残差动作
    else:
        action_dim_for_critic = action_dim
    # 创建Q网络
    qf1 = SoftQNetwork(proprio_dim, action_dim_for_critic).to(device)
    qf2 = SoftQNetwork(proprio_dim, action_dim_for_critic).to(device)         
    # 创建目标网络
    qf1_target = SoftQNetwork(proprio_dim, action_dim_for_critic).to(device)
    qf2_target = SoftQNetwork(proprio_dim, action_dim_for_critic).to(device)
    qf1_target.load_state_dict(qf1.state_dict())
    qf2_target.load_state_dict(qf2.state_dict())
    
    # 设置优化器
    q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr)
    actor_optimizer = optim.Adam(list(res_actor.parameters()), lr=args.policy_lr)

    # Automatic entropy tuning
    # 自动调整熵正则化系数
        # 设置SAC的熵正则化参数
    if args.autotune:
        # 使用动作维度作为目标熵（对于连续动作空间的标准做法）
        target_entropy = -action_dim
        # 初始化log_alpha为0，这意味着初始alpha=1.0
        log_sac_alpha = torch.zeros(1, requires_grad=True, device=device)
        sac_alpha = log_sac_alpha.exp().item()
        a_optimizer = optim.Adam([log_sac_alpha], lr=args.q_lr)
    else:
        # 使用固定的alpha值
        sac_alpha = args.sac_alpha
    
    # 创建经验回放缓冲区
    rb = ReplayBuffer(
        args.buffer_size,
        observation_space=gym.spaces.Box(low=-np.inf, high=np.inf, shape=(proprio_dim,)),
        action_space=gym.spaces.Box(
            low=-np.inf, high=np.inf, 
            shape=(action_dim_for_critic * 3 if args.critic_input != 'res' else action_dim,)
        ),
        device=device,
        handle_timeout_termination=False,
    )
    # 创建本次运行的结果目录（在 envs = make_env() 之后添加）
    results_dir = setup_logging_dir(args)
    print(f"Saving results to: {results_dir}")
    domain_info = get_domain_info(args)
    domain_info_path = os.path.join(results_dir, "domain_info.json")
    with open(domain_info_path, "w") as f:
        json.dump(convert_to_serializable(domain_info), f, indent=4)

    # TRY NOT TO MODIFY: start the game
    # 开始训练
    # envs.reset(seed=args.seed) # in Gymnasium, seed is given to reset() instead of seed()
    
    obs = envs.get_obs()
    global_step = 0
    global_update = 0
    learning_has_started = False
    num_updates_per_training = int(args.training_freq * args.utd) #16
    result = defaultdict(list)
    # step_in_episodes = np.zeros((args.num_envs,1,1), dtype=np.int32)
    timer = NonOverlappingTimeProfiler()
    ball = envs.scene.object_registry("name", "ball0")
    initial_height = ball.get_position()[2]

    episode_count = 0  # 新增：用于追踪episode数量
    episode_successes = []
    episode_log = []
    res_action_usage = {
        'total_steps': 0,
        'res_steps': 0
    }
    last_saved_count = 0
    results_dir = setup_logging_dir(args)
    print(f"Saving results to: {results_dir}")

    # 加载检查点（如果有）
    if args.resume_path:
        success, loaded_global_step = load_checkpoint(
            args.resume_path,
            res_actor, qf1, qf2, qf1_target, qf2_target,
            log_sac_alpha if args.autotune else None,
            actor_optimizer, q_optimizer,
            a_optimizer if args.autotune else None,
            policy_lr=args.policy_lr,
            q_lr=args.q_lr,
            load_global_step=args.load_global_step  # 使用命令行参数控制是否加载全局步数
        )
        if success:
            global_step = loaded_global_step
            print(f"成功从步骤 {global_step} 恢复训练")

    ################################## Main training loop ##################################
    while global_step < args.total_timesteps:
        # ---- Collection phase: gather `args.training_freq` environment steps ----
        for local_step in range(args.training_freq):
            global_step += 1
            res_action_usage['total_steps'] += 1
            obs_dict = preprocess_observation(obs, base_policy)

            # Query the frozen base (diffusion) policy; no gradients are needed.
            with torch.no_grad():
                base_act_seq = base_policy.predict_action(obs_dict)['action']  # presumably (B, act_horizon, act_dim) — TODO confirm
                # First action of the predicted sequence, moved to the CPU.
                base_action = base_act_seq.squeeze(0)[0].cpu()

            # Progressive exploration: the probability of enabling the residual
            # grows linearly with `global_step` until it reaches 1 at
            # `args.prog_explore` steps.
            res_ratio = min(global_step / args.prog_explore, 1)
            enable_res = np.random.rand() < res_ratio
            # The residual is always disabled for the first
            # `args.prog_explore_th` steps.
            if(global_step<=args.prog_explore_th):
                enable_res=False
            # ALGO LOGIC: put action logic here
            # Compute the residual action from the proprioceptive state.
            proprio = extract_proprio(envs,obs)
            proprio_tensor = torch.FloatTensor(proprio).unsqueeze(0).to(device)
            if not learning_has_started:
                # NOTE(review): the uniform sample is overwritten with zeros on
                # the next line, so no residual is applied before learning
                # starts; the draw only advances NumPy's global RNG.
                res_action = np.random.uniform(-1, 1, size=action_dim)
                res_action = np.zeros_like(res_action)
            else:
                # The actor sees either the proprio state alone ('obs') or the
                # proprio state concatenated with the base action.
                actor_input = proprio_tensor if args.actor_input == 'obs' else \
                    torch.cat([proprio_tensor, torch.FloatTensor(base_action).unsqueeze(0).to(device)], dim=1)
                res_action, _, _ = res_actor.get_action(actor_input)
                res_action = res_action.detach().cpu().numpy()
                # The actor is still evaluated when the residual is disabled;
                # its output is then replaced by zeros.
                if not enable_res:
                    res_action = np.zeros_like(res_action)

            # Combine the base action with the scaled residual.
            act_dim=18  # hard-coded action dimensionality
            res_act_seq = res_action.reshape(-1, act_dim)
            scaled_res_seq = args.res_scale * res_act_seq
            # Convert to a tensor on `device` so it can be added to the base sequence.
            scaled_res_seq = torch.FloatTensor(scaled_res_seq).to(device)
            if enable_res and learning_has_started:
                res_action_usage['res_steps'] += 1
                # Debug output printed on every step where the residual is active.
                print("global_step:",global_step)
                print("base_act_seq:",base_action)
                print("scaled_res_seq:",scaled_res_seq)
            final_act_seq = base_act_seq + scaled_res_seq
            # Only the first action of the first batch element is executed.
            final_act_seq = final_act_seq[0, 0, :]

            # TRY NOT TO MODIFY: execute the game and log data.
            final_act_seq = final_act_seq.detach().cpu().numpy()
            next_obs_seq, rewards, terminations, truncations, infos = envs.step(final_act_seq)
            # The observation returned by `step` is discarded and replaced by a
            # fresh read from the environment.
            next_obs_seq = envs.get_obs()
            rewards = rewards - 1.0 # negative reward + bootstrap at truncated yields best results
            # TRY NOT TO MODIFY: record rewards for plotting purposes
            result = collect_episode_info(infos, result)

            # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
            # Decide whether the stored transition should stop bootstrapping.
            episode_success = terminations
            if args.bootstrap_at_done == 'never':
                stop_bootstrap = truncations | terminations # always stop bootstrap when episode ends
            else:
                # NOTE(review): `need_final_obs` is assigned in both branches
                # but never read in this loop.
                if args.bootstrap_at_done == 'always':
                    need_final_obs = truncations | terminations # always need final obs when episode ends
                    stop_bootstrap = np.zeros_like(terminations, dtype=bool) # never stop bootstrap
                else: # bootstrap at truncated
                    need_final_obs = truncations & (~terminations) # only need final obs when truncated and not terminated
                    stop_bootstrap = terminations # only stop bootstrap when terminated, don't stop when truncated

            # Build the action record stored in the replay buffer.
            if args.critic_input == 'res' and args.actor_input == 'obs':
                actions_to_save = res_action
            else:  # sum or concat both need base actions for s and s'
                # Query the base policy on the next observation as well.
                next_obs_dict = preprocess_observation(next_obs_seq, base_policy)
                with torch.no_grad():
                    base_next_act_seq = base_policy.predict_action(next_obs_dict)['action']
                    base_next_action = base_next_act_seq.squeeze(0)[0].cpu().numpy()
                base_action = base_action.cpu().numpy()
                # Each of the three pieces becomes a single row.
                res_action = res_action.reshape(1, -1)
                base_action = base_action.reshape(1, -1)
                base_next_action = base_next_action.reshape(1, -1)
                # np.concatenate defaults to axis 0, so the rows are stacked
                # in the order [residual, base, next base].
                # NOTE(review): the update phase slices `data.actions` along
                # columns; this presumably relies on `rb.add` flattening the
                # stacked rows into one vector — confirm against the buffer.
                actions_to_save = np.concatenate([res_action, base_action, base_next_action])
            # Add the transition to the replay buffer.
            next_proprio = extract_proprio(envs,next_obs_seq)
            rb.add(proprio, next_proprio, actions_to_save, rewards, stop_bootstrap, infos)
            # Same value as assigned before the bootstrap logic above.
            episode_success = terminations
            if terminations or truncations:
                episode_count += 1
                # A terminated episode is recorded as a success (1.0),
                # a truncated one as a failure (0.0).
                episode_successes.append(float(episode_success))
                if terminations:
                    serializable_log = convert_to_serializable(env_augmentor.log)
                    episode_log.append((serializable_log, "success"))
                else:  # truncations
                    serializable_log = convert_to_serializable(env_augmentor.log)
                    episode_log.append((serializable_log, "failure"))

                # Reset the environment and start the next episode from a
                # freshly read observation.
                print("############### 重置环境！################")
                env_reset(envs,env_augmentor)
                obs = envs.get_obs()
            else:
                obs = next_obs_seq
        timer.end('collect')
        # print("end collect")
        # print("rb.size() * rb.n_envs:",rb.size() * rb.n_envs)

        # Every `args.save_interval` finished episodes, write a plain-text
        # summary file. `last_saved_count` stops the same milestone from being
        # written again while `episode_count` has not changed.
        if (
            episode_count % args.save_interval == 0
            and episode_count > 0
            and episode_count != last_saved_count
        ):
            last_saved_count = episode_count
            # Success rate over the most recent `args.save_interval` episodes.
            recent_success_rate = float(np.mean(episode_successes[-args.save_interval:]))
            # Fraction of environment steps on which the residual was applied.
            res_usage_ratio = res_action_usage['res_steps'] / res_action_usage['total_steps']
            log_filename = os.path.join(results_dir, f"step_{global_step}.log")
            with open(log_filename, "w") as report_file:
                report_lines = [
                    "Training Results:\n",
                    f"Global Step: {global_step}\n",
                    f"Episode Count: {episode_count}\n",
                    f"Success Rate: {recent_success_rate*100:.1f}% (averaged over last {args.save_interval} episodes)\n",
                    f"Res Action Usage Ratio: {res_usage_ratio*100:.1f}%\n",
                    "\nEpisode Logs:\n",
                ]
                # One line per episode record in the same window.
                report_lines.extend(
                    f"{str(entry)}\n" for entry in episode_log[-args.save_interval:]
                )
                report_file.writelines(report_lines)

            print(f"\nSaved results at episode {episode_count} (step {global_step})")
            print(f"Recent Success Rate: {recent_success_rate*100:.1f}%")

        # ALGO LOGIC: training.
        # 训练开始
        if rb.size() * rb.n_envs < args.learning_starts: #1000步以后开始训练 
            # print("continue collect")
            continue
        else:
            # print("start train")
            learning_has_started = True

        # ---- Update phase: `num_updates_per_training` gradient steps ----
        for local_update in range(num_updates_per_training):
            global_update += 1
            data = rb.sample(args.batch_size)
            # Unpack the stored action record.
            if args.critic_input != 'res' or args.actor_input == 'obs_base_action':
                # The record holds three consecutive segments of width
                # `total_act_dim`: residual action, base action at s, and base
                # action at s' (taken from the last `total_act_dim` columns).
                total_act_dim=18  # hard-coded action dimensionality
                res_action = data.actions[:, :total_act_dim]
                base_actions = data.actions[:, total_act_dim:total_act_dim*2]
                # NOTE(review): moved to the CPU here and back to `device`
                # wherever it is used below.
                base_next_actions = data.actions[:, -total_act_dim:].cpu()
            else:
                # Only the residual action was stored.
                res_action = data.actions

            #############################################
            # Train agent
            #############################################
            # update the value networks
            with torch.no_grad():
                # Actor input for the next state: observation alone, or
                # observation concatenated with the next base action.
                if args.actor_input == 'obs':
                    actor_input = data.next_observations
                else:
                    base_next_actions = base_next_actions.to(device)
                    actor_input = torch.cat([data.next_observations, base_next_actions], dim=1)

                # Sample the residual action (and its log-probability) at s'.
                next_state_res_actions, next_state_log_pi, _ = res_actor.get_action(actor_input)
                # Build the critic's action input according to `critic_input`.
                if args.critic_input == 'res':
                    next_state_actions = next_state_res_actions
                elif args.critic_input == 'sum':
                    scaled_res_actions = (args.res_scale * next_state_res_actions).to(device)
                    base_next_actions = base_next_actions.to(device)
                    next_state_actions = base_next_actions + scaled_res_actions
                else:  # concat
                    base_next_actions = base_next_actions.to(device)
                    next_state_actions = torch.cat([next_state_res_actions, base_next_actions], dim=1)

                # Target value: element-wise minimum of the two target critics
                # minus the entropy term `sac_alpha * log_pi`.
                qf1_next_target = qf1_target(data.next_observations, next_state_actions)
                qf2_next_target = qf2_target(data.next_observations, next_state_actions)
                min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - sac_alpha * next_state_log_pi

                rewards = data.rewards.to(device)
                dones = data.dones.to(device)

                # One-step TD target; everything is flattened to 1-D first.
                # `dones` holds the `stop_bootstrap` flag stored at collection
                # time, so the bootstrap term is dropped where it is set.
                next_q_value = rewards.flatten() + (1 - dones.flatten()) * args.gamma * (min_qf_next_target).view(-1)
                next_q_value = next_q_value.detach()

            # Critic action input for the current state, mirroring the
            # next-state construction above but using the stored actions.
            if args.critic_input == 'res':
                current_actions = res_action.to(device)
            elif args.critic_input == 'sum':
                scaled_res_actions = (args.res_scale * res_action).to(device)
                base_actions = base_actions.to(device)
                current_actions = base_actions + scaled_res_actions
            else:  # concat
                res_action = res_action.to(device)
                base_actions = base_actions.to(device)
                current_actions = torch.cat([res_action, base_actions], dim=1)

            # Critic update: both critics regress onto the same TD target and
            # share one optimizer step; gradients are clipped per critic.
            observations = data.observations.to(device)
            qf1_a_values = qf1(observations, current_actions).view(-1)
            qf2_a_values = qf2(observations, current_actions).view(-1)
            qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
            qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
            qf_loss = qf1_loss + qf2_loss

            q_optimizer.zero_grad()
            qf_loss.backward()
            qf1_grad_norm = nn.utils.clip_grad_norm_(qf1.parameters(), args.max_grad_norm)
            qf2_grad_norm = nn.utils.clip_grad_norm_(qf2.parameters(), args.max_grad_norm)
            q_optimizer.step()

            # update the policy network (every `args.policy_frequency` updates)
            if global_update % args.policy_frequency == 0:
                if args.actor_input == 'obs':
                    actor_input = observations
                else:
                    base_actions = base_actions.to(device)
                    actor_input = torch.cat([observations, base_actions], dim=1)

                # res_pi: sampled residual action; log_pi: its log-probability,
                # used for the entropy term.
                res_pi, log_pi, _ = res_actor.get_action(actor_input)
                if args.critic_input == 'res':
                    pi = res_pi
                elif args.critic_input == 'sum':
                    scaled_res_actions = (args.res_scale * res_pi).to(device)
                    base_actions = base_actions.to(device)
                    pi = base_actions + scaled_res_actions
                else:  # concat
                    base_actions = base_actions.to(device)
                    pi = torch.cat([res_pi, base_actions], dim=1)
                # Q-values of the freshly sampled action under both critics.
                qf1_pi = qf1(observations, pi)
                qf2_pi = qf2(observations, pi)
                min_qf_pi = torch.min(qf1_pi, qf2_pi)
                # Minimise (entropy term - Q), i.e. maximise Q while keeping
                # the policy entropy high.
                actor_loss = ((sac_alpha * log_pi) - min_qf_pi).mean()
                actor_optimizer.zero_grad()
                actor_loss.backward()
                actor_grad_norm = nn.utils.clip_grad_norm_(res_actor.parameters(), args.max_grad_norm)
                actor_optimizer.step()

                # Automatic tuning of the entropy coefficient.
                if args.autotune:
                    # Re-sample log-probabilities from the just-updated actor
                    # without tracking gradients; only `log_sac_alpha` is trained.
                    with torch.no_grad():
                        _, log_pi, _ = res_actor.get_action(actor_input)
                    sac_alpha_loss = (-log_sac_alpha * (log_pi + target_entropy)).mean()
                    a_optimizer.zero_grad()
                    sac_alpha_loss.backward()
                    a_optimizer.step()
                    # Keep a plain Python float of the coefficient for the losses above.
                    sac_alpha = log_sac_alpha.exp().item()

            # update the target networks: soft update
            # target <- tau * online + (1 - tau) * target
            if global_update % args.target_network_frequency == 0:
                for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
                for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)

        timer.end('train')

        # Write TensorBoard scalars whenever this outer iteration crossed a
        # multiple of `args.log_freq` environment steps.
        step_before_iteration = global_step - args.training_freq
        if step_before_iteration // args.log_freq < global_step // args.log_freq:
            log_scalar = writer.add_scalar
            # Flush the per-episode statistics (mean over the collected
            # samples) and start a fresh accumulator.
            if len(result['return']) > 0:
                for metric_name, samples in result.items():
                    log_scalar(f"train/{metric_name}", np.mean(samples), global_step)
                result = defaultdict(list)

            # Values from the most recent gradient update.
            log_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
            log_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step)
            log_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
            log_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
            log_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)
            log_scalar("losses/actor_loss", actor_loss.item(), global_step)
            log_scalar("losses/sac_alpha", sac_alpha, global_step)
            log_scalar("losses/qf1_grad_norm", qf1_grad_norm.item(), global_step)
            log_scalar("losses/qf2_grad_norm", qf2_grad_norm.item(), global_step)
            log_scalar("losses/actor_grad_norm", actor_grad_norm.item(), global_step)
            timer.dump_to_writer(writer, global_step)
            if args.autotune:
                log_scalar("losses/sac_alpha_loss", sac_alpha_loss.item(), global_step)
        

        
        # Save a checkpoint when checkpointing is enabled (`args.save_freq` is
        # truthy) and either training has reached its final step or this outer
        # iteration crossed a multiple of `args.save_freq`.
        if args.save_freq:
            training_finished = global_step >= args.total_timesteps
            if training_finished or (
                (global_step - args.training_freq) // args.save_freq
                < global_step // args.save_freq
            ):
                os.makedirs(f'{log_path}/checkpoints', exist_ok=True)
                print("############ 保存检查点！##############")
                checkpoint_payload = {
                    'res_actor': res_actor.state_dict(),
                    'q1': qf1.state_dict(),
                    'q2': qf2.state_dict(),
                    'q1_target': qf1_target.state_dict(),
                    'q2_target': qf2_target.state_dict(),
                    # Learned tensor when autotuning, otherwise the log of the
                    # fixed coefficient.
                    'log_sac_alpha': log_sac_alpha if args.autotune else np.log(args.sac_alpha),
                    # Optimizer states, so training can resume.
                    'actor_optimizer': actor_optimizer.state_dict(),
                    'q_optimizer': q_optimizer.state_dict(),
                    'a_optimizer': a_optimizer.state_dict() if args.autotune else None,
                    'global_step': global_step,
                }
                torch.save(checkpoint_payload, f'{log_path}/checkpoints/{global_step}.pt')

    # Training finished: close the environment and the TensorBoard writer.
    envs.close()
    writer.close()